import asyncio

from py_pli.pylib import VUnits
from py_pli.pylib import GlobalVar

from config_enum import measurement_unit_enum as meas_config

from virtualunits.HAL import HAL
from virtualunits.vu_node_application import VUNodeApplication
from virtualunits.vu_measurement_unit import VUMeasurementUnit
from virtualunits.VirtualTemperatureUnit import VirtualTemperatureUnit
from virtualunits.meas_seq_generator import meas_seq_generator
from virtualunits.meas_seq_generator import TriggerSignal
from virtualunits.meas_seq_generator import OutputSignal
from virtualunits.meas_seq_generator import MeasurementChannel
from virtualunits.meas_seq_generator import IntegratorMode
from virtualunits.meas_seq_generator import AnalogControlMode

from predefined_tasks.common.helper import send_to_gc

from fleming.common.node_io import EEFAnalogInput

# Hardware handles, resolved once when this module is imported from VUnits.instance.hal.
# All sequence builders and bench tests below share these objects.
hal_unit: HAL = VUnits.instance.hal
# EEF node: firmware start and the PMT1 high-voltage monitor analog input.
eef_unit: VUNodeApplication = hal_unit.nodes['EEFNode']
# Mainboard node: firmware start.
fmb_unit: VUNodeApplication = hal_unit.nodes['Mainboard']
# Measurement unit: loads/executes trigger sequences, reads results and GC config values.
meas_unit: VUMeasurementUnit = hal_unit.measurementUnit
# PMT1 cooler; the bench tests set its target to 18.0 and enable it.
pmt1_cooling: VirtualTemperatureUnit = hal_unit.pmt_ch1_Cooling

# TODO:
# - Validate analog results against counting results
# - Check the CPS conversion

# Flash Lum Measurement Sequence #######################################################################################

async def load_flash_lum_operation(op_id, meas_time_ms, analog_meas_time_ms=0.1, analog_dead_time_ms=0.5, repeats=1, interval_ms=0.0):
    """Build a flash lum trigger sequence and load it into the measurement unit as *op_id*.

    Six result addresses are reserved per repeat
    (pmt1_cnt, pmt1_al, pmt1_ah, pmt2_cnt, pmt2_al, pmt2_ah).
    """
    results_per_repeat = 6
    trigger_sequence = flash_lum_measurement(meas_time_ms, analog_meas_time_ms, analog_dead_time_ms, repeats, interval_ms)
    meas_unit.resultAddresses[op_id] = range(repeats * results_per_repeat)
    await meas_unit.LoadTriggerSequence(op_id, trigger_sequence)


def flash_lum_measurement(meas_time_ms, analog_meas_time_ms, analog_dead_time_ms, repeats, interval_ms):
    if (meas_time_ms < analog_meas_time_ms):
        raise ValueError(f"meas_time_ms must be greater or equal to analog_meas_time_ms")
    if (meas_time_ms > 10000):
        # Limiting the meas_time_ms to 10s allows using 32 bit results for counting and analog.
        # Counting: (2^32 - 1) / 10s = ~429e6 CPS is greater than 250e6 estimated max CPS
        # Analog:   10000ms / (0.02ms + 0.5ms) = ~19231 is smaller than 2^32 / 2^16 = 2^16 = 65536
        raise ValueError(f"meas_time_ms must be smaller or equal to 10000 ms")
    if (analog_meas_time_ms < 0.02) or (analog_meas_time_ms > 65.536):
        raise ValueError(f"analog_meas_time_ms must be in the range [0.02, 65.536] ms")
    if (repeats < 1) or (repeats > 682):
        raise ValueError(f"repeats must be in the range [1, 682]")
    if (repeats > 1) and (interval_ms < meas_time_ms):
        raise ValueError(f"interval_ms must be greater or equal to meas_time_ms")
    if (repeats > 1) and (interval_ms < (analog_meas_time_ms + analog_dead_time_ms)):
        raise ValueError(f"interval_ms must be greater or equal to (analog_meas_time_ms + analog_dead_time_ms)")
    if (interval_ms > 43980465):
        raise ValueError(f"interval_ms must be smaller or equal to 43980465 ms")

    full_reset_delay = 40000    # 400 us
    conversion_delay = 1200     #  12 us
    range_switch_delay = 25     # 250 ns
    reset_switch_delay = 2000   #  20 us
    input_gate_delay = 100      #   1 us
    pre_cnt_window = 100        #   1 us
    fixed_range = 2000          #  20 us

    dead_time = full_reset_delay + 2 * conversion_delay + range_switch_delay + reset_switch_delay + input_gate_delay + pre_cnt_window + fixed_range  # 466.25 us

    if (analog_dead_time_ms < (dead_time / 1e5)) or (analog_dead_time_ms > 671.08864):
        raise ValueError(f"analog_dead_time_ms must be in the range [{dead_time / 1e5:.5f}, 671.08864] ms")

    # Add an additional delay to reach the defined effective analog dead time.
    analog_window_delay = round(analog_dead_time_ms * 1e5) - dead_time

    analog_window_us = round(analog_meas_time_ms * 1000)
    analog_window_count = flash_lum_get_analog_window_count(meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

    interval_delay = round((interval_ms - (analog_meas_time_ms + analog_dead_time_ms) * analog_window_count) * 1e5)
    interval_delay_coarse, interval_delay_fine = divmod(interval_delay, 2**26)

    seq_gen = meas_seq_generator()

    # results = [pmt1_cnt, pmt1_al, pmt1_ah, pmt2_cnt, pmt2_al, pmt2_ah, ...]
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    seq_gen.Loop(repeats * 6)
    seq_gen.ClearResultBuffer(relative=True, dword=False, addrReg=0, addr=0)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()

    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)

    seq_gen.ResetSignals(OutputSignal.InputGatePMT1 | OutputSignal.InputGatePMT2)
    seq_gen.ResetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)
    # seq_gen.SetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)

    seq_gen.Loop(repeats)

    seq_gen.Loop(analog_window_count)
    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.full_offset_reset, pmt2=AnalogControlMode.full_offset_reset)

    seq_gen.TimerWaitAndRestart(full_reset_delay)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.full_reset, pmt2=IntegratorMode.full_reset)

    seq_gen.TimerWaitAndRestart(conversion_delay)
    seq_gen.SetTriggerOutput(TriggerSignal.SamplePMT1 | TriggerSignal.SamplePMT2)

    seq_gen.TimerWaitAndRestart(range_switch_delay)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.low_range_reset, pmt2=IntegratorMode.low_range_reset)
    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.read_offset, pmt2=AnalogControlMode.read_offset)

    seq_gen.TimerWaitAndRestart(reset_switch_delay)
    seq_gen.SetTriggerOutput(TriggerSignal.SamplePMT1 | TriggerSignal.SamplePMT2)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.integrate_autorange, pmt2=IntegratorMode.integrate_autorange)

    seq_gen.SetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)

    seq_gen.TimerWaitAndRestart(input_gate_delay)
    seq_gen.SetSignals(OutputSignal.InputGatePMT1 | OutputSignal.InputGatePMT2)

    seq_gen.TimerWaitAndRestart(pre_cnt_window)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT2, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.Loop(analog_window_us)
    seq_gen.TimerWaitAndRestart(pre_cnt_window)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT2, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    seq_gen.LoopEnd()

    seq_gen.ResetSignals(OutputSignal.InputGatePMT1 | OutputSignal.InputGatePMT2)
    seq_gen.ResetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)

    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.read_offset, pmt2=AnalogControlMode.read_offset)

    seq_gen.TimerWaitAndRestart(fixed_range)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.integrate_with_fixed_range, pmt2=IntegratorMode.integrate_with_fixed_range)
    
    seq_gen.TimerWaitAndRestart(conversion_delay)
    seq_gen.SetTriggerOutput(TriggerSignal.SamplePMT1 | TriggerSignal.SamplePMT2)
    seq_gen.GetPulseCounterResult(MeasurementChannel.PMT1, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=0)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT1, isRelativeAddr=True, ignoreRange=False, isHiRange=False, addResult=True, dword=False, addrPos=0, resultPos=1)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT1, isRelativeAddr=True, ignoreRange=False, isHiRange=True, addResult=True, dword=False, addrPos=0, resultPos=2)
    seq_gen.GetPulseCounterResult(MeasurementChannel.PMT2, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=3)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT2, isRelativeAddr=True, ignoreRange=False, isHiRange=False, addResult=True, dword=False, addrPos=0, resultPos=4)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT2, isRelativeAddr=True, ignoreRange=False, isHiRange=True, addResult=True, dword=False, addrPos=0, resultPos=5)

    if analog_window_delay > 0:
        seq_gen.TimerWaitAndRestart(analog_window_delay)
    seq_gen.LoopEnd()

    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=6)

    if interval_delay_coarse > 0:
        seq_gen.Loop(interval_delay_coarse)
        seq_gen.TimerWaitAndRestart(2**26)
        seq_gen.LoopEnd()
    if interval_delay_fine > 0:
        seq_gen.TimerWaitAndRestart(interval_delay_fine)
    seq_gen.LoopEnd()

    seq_gen.ResetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)
    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.full_offset_reset, pmt2=AnalogControlMode.full_offset_reset)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.full_reset, pmt2=IntegratorMode.full_reset)
    seq_gen.Stop(0)

    return seq_gen.currSequence


# Flash Lum Utility Functions ##########################################################################################

def flash_lum_counting_to_cps(pmt, cnt, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms):
    """Convert a flash lum pulse count into dead-time corrected counts per second.

    The count is divided by the total gated time
    (``analog_meas_time_ms / 1000 * analog_window_count``). It is then corrected with
    the PMT's pulse pair resolution config value ``ppr`` as ``cps / (1 - cps * ppr)``.

    Args:
        pmt: 'pmt1' or 'pmt2'.
        cnt: Raw pulse count of one measurement.
        meas_time_ms: Measurement time used for the sequence, in ms.
        analog_meas_time_ms: Analog window length used for the sequence, in ms.
        analog_dead_time_ms: Analog dead time used for the sequence, in ms.

    Returns:
        The corrected count rate in counts per second.

    Raises:
        ValueError: If *pmt* is neither 'pmt1' nor 'pmt2'.
    """
    if pmt == 'pmt1':
        ppr_param = meas_config.GC_Params.Pulse_pair_res_PMT1_s
    elif pmt == 'pmt2':
        ppr_param = meas_config.GC_Params.Pulse_pair_res_PMT2_s
    else:
        # Previously fell through with `ppr` unbound and failed later with a confusing NameError.
        raise ValueError(f"pmt must be 'pmt1' or 'pmt2', got {pmt!r}")
    ppr = meas_unit.get_config(ppr_param)

    analog_window_count = flash_lum_get_analog_window_count(meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)
    cps = cnt / (analog_meas_time_ms / 1000 * analog_window_count)
    cps = cps / (1 - cps * ppr)

    return cps


def flash_lum_analog_to_cps(pmt, al, ah, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms):
    """Convert flash lum low/high range analog results into counts per second.

    The counting equivalent is ``al * als + ah * ahs * als``. Here ``als`` is the
    PMT's analog counting equivalent and ``ahs`` its high range scale (both read from
    the GC config). The equivalent is divided by the total gated time
    (``analog_meas_time_ms / 1000 * analog_window_count``).

    Args:
        pmt: 'pmt1' or 'pmt2'.
        al: Low range analog result.
        ah: High range analog result.
        meas_time_ms: Measurement time used for the sequence, in ms.
        analog_meas_time_ms: Analog window length used for the sequence, in ms.
        analog_dead_time_ms: Analog dead time used for the sequence, in ms.

    Returns:
        The analog signal expressed as counts per second.

    Raises:
        ValueError: If *pmt* is neither 'pmt1' nor 'pmt2'.
    """
    if pmt == 'pmt1':
        als_param = meas_config.GC_Params.AnalogCountingEquivalent_PMT1
        ahs_param = meas_config.GC_Params.AnalogHighRangeScale_PMT1
    elif pmt == 'pmt2':
        als_param = meas_config.GC_Params.AnalogCountingEquivalent_PMT2
        ahs_param = meas_config.GC_Params.AnalogHighRangeScale_PMT2
    else:
        # Previously fell through with `als`/`ahs` unbound and failed later with a confusing NameError.
        raise ValueError(f"pmt must be 'pmt1' or 'pmt2', got {pmt!r}")
    als = meas_unit.get_config(als_param)
    ahs = meas_unit.get_config(ahs_param)

    analog_window_count = flash_lum_get_analog_window_count(meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)
    cps = (al * als + ah * ahs * als) / (analog_meas_time_ms / 1000 * analog_window_count)

    return cps


def flash_lum_get_analog_window_count(meas_time_ms, analog_meas_time_ms, analog_dead_time_ms):
    """Return how many whole analog windows (measure + dead time) fit into meas_time_ms.

    At least one window is always used, even when meas_time_ms is shorter than a single
    window.
    """
    window_ms = analog_meas_time_ms + analog_dead_time_ms
    return int(meas_time_ms / window_ms) if meas_time_ms > window_ms else 1


# Flash Lum Bench Tests ################################################################################################

async def flash_lum_test(meas_time_ms=1000, analog_meas_time_ms=0.1, analog_dead_time_ms=0.5, led_source='smu', led_channel='led1', led_current=200):
    """Run one flash lum measurement on PMT1 and report raw and CPS results.

    The routine performs these steps:

    - starts the Mainboard and EEF firmware;
    - enables PMT1 cooling with a target of 18.0;
    - switches the PMT high voltage on and drives the green LED with ``led_current``;
    - executes a single flash lum operation;
    - prints the raw PMT1 results, the HV monitor reading and both CPS conversions.

    HV and LED are switched off at the end.
    """
    from predefined_tasks.pmt_adjust import set_led_current

    op_id = 'flash_lum'
    led_kwargs = dict(source=led_source, channel=led_channel, led_type='green')

    await send_to_gc("Starting Firmware")
    await asyncio.gather(fmb_unit.StartFirmware(), eef_unit.StartFirmware())

    await pmt1_cooling.InitializeDevice()
    await pmt1_cooling.set_target_temperature(18.0)
    await pmt1_cooling.enable()

    await meas_unit.InitializeDevice()
    await meas_unit.Set_PMT_HV(True)
    await set_led_current(led_current, **led_kwargs)

    await asyncio.sleep(1.0)

    meas_unit.ClearOperations()
    await load_flash_lum_operation(op_id, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

    await meas_unit.ExecuteMeasurement(op_id)
    # Sample the HV monitor at 80 % of the measurement time, then collect the results.
    await asyncio.sleep(meas_time_ms / 1000 * 0.8)
    hv_monitor = await eef_unit.GetAnalogInput(EEFAnalogInput.PMT1HIGHVOLTAGEMONITOR)
    hv = hv_monitor[0] * 1500
    results = await meas_unit.ReadMeasurementValues(op_id)
    pmt1_cnt, pmt1_al, pmt1_ah = results[0], results[1], results[2]

    await send_to_gc(" ")
    await send_to_gc(f"pmt1_cnt: {pmt1_cnt} ; pmt1_al: {pmt1_al} ; pmt1_ah: {pmt1_ah} ; hv: {hv}", log=True)

    counting_cps = flash_lum_counting_to_cps('pmt1', pmt1_cnt, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)
    analog_cps = flash_lum_analog_to_cps('pmt1', pmt1_al, pmt1_ah, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

    await send_to_gc(" ")
    await send_to_gc(f"pmt1_counting_cps : {counting_cps:9.0f}", log=True)
    await send_to_gc(f"pmt1_analog_cps   : {analog_cps:9.0f}", log=True)

    await meas_unit.Set_PMT_HV(False)
    await set_led_current(0, **led_kwargs)

    await send_to_gc(" ")
    return "flash_lum_test done"


async def flash_lum_signal_scan(meas_time_ms=1000, analog_meas_time_ms=0.1, analog_dead_time_ms=0.5, led_source='smu', led_channel='led1', led_current_start=10, led_current_stop=200, led_current_step=10):
    """Scan the green LED current and log PMT1 flash lum results for each step.

    LED currents run from ``led_current_start`` to ``led_current_stop`` (inclusive when
    aligned to the step, never beyond it) in ``led_current_step`` increments. The scan
    can be stopped via ``GlobalVar.set_stop_gc``. PMT HV and the LED are switched off
    in both cases, whether the scan completes or is stopped.
    """
    from predefined_tasks.pmt_adjust import set_led_current

    GlobalVar.set_stop_gc(False)

    await send_to_gc(f"Starting Firmware")
    await asyncio.gather(
        fmb_unit.StartFirmware(),
        eef_unit.StartFirmware(),
    )

    await pmt1_cooling.InitializeDevice()
    await pmt1_cooling.set_target_temperature(18.0)
    await pmt1_cooling.enable()

    await meas_unit.InitializeDevice()
    await meas_unit.Set_PMT_HV(True)

    await asyncio.sleep(1.0)

    op_id = 'flash_lum'
    meas_unit.ClearOperations()
    await load_flash_lum_operation(op_id, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

    await send_to_gc(f" ")
    await send_to_gc(f"current ; power   ; pmt1_cnt   ; pmt1_al    ; pmt1_ah    ; pmt1_counting ; pmt1_analog   ; hv   ;", log=True)

    # Inclusive stop: `stop + step` as range end would overshoot a stop value that is not
    # aligned to the step (e.g. stop=195, step=10 would still drive 200).
    led_current_end = led_current_stop + (1 if led_current_step > 0 else -1)
    stopped_by_user = False

    for led_current in range(led_current_start, led_current_end, led_current_step):

        if GlobalVar.get_stop_gc():
            stopped_by_user = True
            break

        await set_led_current(led_current, source=led_source, channel=led_channel, led_type='green')
        await asyncio.sleep(0.5)

        await meas_unit.ExecuteMeasurement(op_id)
        await asyncio.sleep(meas_time_ms / 1000 * 0.8)
        hv = (await eef_unit.GetAnalogInput(EEFAnalogInput.PMT1HIGHVOLTAGEMONITOR))[0] * 1500
        results = await meas_unit.ReadMeasurementValues(op_id)

        pmt1_counting_cps = flash_lum_counting_to_cps('pmt1', results[0], meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)
        pmt1_analog_cps = flash_lum_analog_to_cps('pmt1', results[1], results[2], meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

        await send_to_gc(f"{led_current:7d} ; {flash_lum_get_led_power(led_current):7.2f} ; {results[0]:10d} ; {results[1]:10d} ; {results[2]:10d} ; {pmt1_counting_cps:13.0f} ; {pmt1_analog_cps:13.0f} ; {hv:4.0f} ;", log=True)

    # Always switch HV and LED off; the early return on user stop used to skip this.
    await meas_unit.Set_PMT_HV(False)
    await set_led_current(0, source=led_source, channel=led_channel, led_type='green')

    if stopped_by_user:
        return f"flash_lum_signal_scan stopped by user"

    await send_to_gc(f" ")
    return f"flash_lum_signal_scan done"


async def flash_lum_dead_time_scan(meas_time_ms=1000, analog_meas_time_ms=0.1, led_source='smu', led_channel='led1', led_current=110, dt_start=0.5, dt_stop=2.0, dt_step=0.1):
    """Scan the analog dead time at a fixed LED current and log PMT1 flash lum results.

    Dead times run from ``dt_start`` to ``dt_stop`` ms (inclusive when aligned to the
    step, never beyond it) in ``dt_step`` increments, resolved to 1e-6 ms. The
    sequence is rebuilt for every dead time. The scan can be stopped via
    ``GlobalVar.set_stop_gc``. PMT HV and the LED are switched off in both cases,
    whether the scan completes or is stopped.
    """
    from predefined_tasks.pmt_adjust import set_led_current

    GlobalVar.set_stop_gc(False)

    await send_to_gc(f"Starting Firmware")
    await asyncio.gather(
        fmb_unit.StartFirmware(),
        eef_unit.StartFirmware(),
    )

    await pmt1_cooling.InitializeDevice()
    await pmt1_cooling.set_target_temperature(18.0)
    await pmt1_cooling.enable()

    await meas_unit.InitializeDevice()
    await meas_unit.Set_PMT_HV(True)

    await asyncio.sleep(1.0)

    await set_led_current(led_current, source=led_source, channel=led_channel, led_type='green')

    await send_to_gc(f" ")
    await send_to_gc(f"dt   ; pmt1_cnt   ; pmt1_al    ; pmt1_ah    ; pmt1_counting ; pmt1_analog   ; hv   ;", log=True)

    # Work in integer units of 1e-6 ms to avoid float accumulation. The end is inclusive
    # without overshooting a dt_stop that is not aligned to dt_step.
    dt_step_units = round(dt_step * 1e6)
    dt_end_units = round(dt_stop * 1e6) + (1 if dt_step_units > 0 else -1)
    dt_range = [dt / 1e6 for dt in range(round(dt_start * 1e6), dt_end_units, dt_step_units)]

    op_id = 'flash_lum'
    stopped_by_user = False

    for analog_dead_time_ms in dt_range:

        if GlobalVar.get_stop_gc():
            stopped_by_user = True
            break

        meas_unit.ClearOperations()
        await load_flash_lum_operation(op_id, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

        await meas_unit.ExecuteMeasurement(op_id)
        await asyncio.sleep(meas_time_ms / 1000 * 0.8)
        hv = (await eef_unit.GetAnalogInput(EEFAnalogInput.PMT1HIGHVOLTAGEMONITOR))[0] * 1500
        results = await meas_unit.ReadMeasurementValues(op_id)

        pmt1_counting_cps = flash_lum_counting_to_cps('pmt1', results[0], meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)
        pmt1_analog_cps = flash_lum_analog_to_cps('pmt1', results[1], results[2], meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

        await send_to_gc(f"{analog_dead_time_ms:4.2f} ; {results[0]:10d} ; {results[1]:10d} ; {results[2]:10d} ; {pmt1_counting_cps:13.0f} ; {pmt1_analog_cps:13.0f} ; {hv:4.0f} ;", log=True)

        await asyncio.sleep(0.5)

    # Always switch HV and LED off; the early return on user stop used to skip this.
    await meas_unit.Set_PMT_HV(False)
    await set_led_current(0, source=led_source, channel=led_channel, led_type='green')

    if stopped_by_user:
        return f"flash_lum_dead_time_scan stopped by user"

    await send_to_gc(f" ")
    return f"flash_lum_dead_time_scan done"


async def flash_lum_kinetic(repeats=50, interval_ms=100, meas_time_ms=100, analog_meas_time_ms=0.1, analog_dead_time_ms=0.5, led_source='fmb', led_channel='led1', led_current=64):
    """Record a PMT1 flash lum kinetic while the LED plays a flash-like current pulse.

    A single operation with ``repeats`` measurements spaced ``interval_ms`` apart runs
    concurrently with ``flash_lum_led_pulse``. Afterwards one row per repeat is logged:
    time, raw PMT1 results and both CPS conversions. PMT HV is switched off at the end;
    the LED pulse profile itself ends at 0 current.
    """
    GlobalVar.set_stop_gc(False)

    await send_to_gc(f"Starting Firmware")
    await asyncio.gather(
        fmb_unit.StartFirmware(),
        eef_unit.StartFirmware(),
    )

    await pmt1_cooling.InitializeDevice()
    await pmt1_cooling.set_target_temperature(18.0)
    await pmt1_cooling.enable()

    await meas_unit.InitializeDevice()
    await meas_unit.Set_PMT_HV(True)

    await asyncio.sleep(1.0)

    op_id = 'flash_lum'
    meas_unit.ClearOperations()
    await load_flash_lum_operation(op_id, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms, repeats, interval_ms)

    await send_to_gc(f" ")
    await send_to_gc(f"time [s] ; pmt1_cnt   ; pmt1_al    ; pmt1_ah    ; pmt1_counting ; pmt1_analog   ;", log=True)

    await asyncio.gather(
        meas_unit.ExecuteMeasurement(op_id),
        flash_lum_led_pulse(led_current, source=led_source, channel=led_channel, led_type='green'),
    )
    results = await meas_unit.ReadMeasurementValues(op_id)

    for i in range(repeats):
        time_s = i * interval_ms / 1000
        # Six result slots per repeat; PMT1 occupies the first three.
        pmt1_cnt = results[i * 6 + 0]
        pmt1_al  = results[i * 6 + 1]
        pmt1_ah  = results[i * 6 + 2]
        pmt1_counting_cps = flash_lum_counting_to_cps('pmt1', pmt1_cnt, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)
        pmt1_analog_cps = flash_lum_analog_to_cps('pmt1', pmt1_al, pmt1_ah, meas_time_ms, analog_meas_time_ms, analog_dead_time_ms)

        await send_to_gc(f"{time_s:7.3f} ; {pmt1_cnt:10d} ; {pmt1_al:10d} ; {pmt1_ah:10d} ; {pmt1_counting_cps:13.0f} ; {pmt1_analog_cps:13.0f} ;", log=True)

    await meas_unit.Set_PMT_HV(False)

    await send_to_gc(f" ")
    # Was a copy-paste of the signal scan's message.
    return f"flash_lum_kinetic done"


async def flash_lum_led_pulse(led_current=100, source='smu', channel='led1', led_type='green'):
    """Drive the LED through a fixed flash-shaped current profile.

    The profile rises from 0 to ``led_current`` and decays back to 0. Each step sets
    ``round(led_current * scale)`` via ``set_led_current``, with no explicit delay
    between steps.
    """
    from predefined_tasks.pmt_adjust import set_led_current

    profile = (
        0.0, 0.1, 0.2, 0.4, 0.8, 1.0, 0.8, 0.6, 0.5, 0.4, 0.3,
        0.25, 0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.01, 0.005, 0.0,
    )
    for step_current in (round(led_current * scale) for scale in profile):
        await set_led_current(step_current, source, channel, led_type)


def flash_lum_get_led_power(current):
    """Interpolate the LED output power for a given LED current.

    Uses a cubic spline through a fixed current/power calibration table (currents 0 to
    200 in steps of 10). The spline result is returned as produced by
    ``CubicSpline.__call__``.
    """
    from scipy.interpolate import CubicSpline

    calibration = (
        (0, 0), (10, 3.63), (20, 15.47), (30, 36.3), (40, 65.8), (50, 103.3),
        (60, 148.4), (70, 200.4), (80, 258.7), (90, 323.0), (100, 392.0),
        (110, 467.0), (120, 546.0), (130, 630.0), (140, 717.0), (150, 809.0),
        (160, 904.0), (170, 1002.0), (180, 1104.0), (190, 1209.0), (200, 1316.0),
    )
    currents, powers = zip(*calibration)
    return CubicSpline(currents, powers)(current)

